{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 02MOD.ipynb\n",
    "\n",
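    "Two-tower retrieval baseline: a 4-layer BERT text encoder and a 4-layer BERT encoder over image ROI features are trained with an in-batch binary cross-entropy loss and validated with NDCG@5.\n",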
    "\n",
    "05/27/2020"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '3'\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "import pytorch_lightning as pl\n",
    "import random\n",
    "import warnings\n",
    "warnings.simplefilter(action='ignore', category=FutureWarning)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pre-extracted train / validation / test splits.\n",
    "# NOTE: read_pickle can execute arbitrary code on load; only point it at trusted files.\n",
    "tra_df, val_df, test_df = map(pd.read_pickle, (\n",
    "    'extracted/train.pkl',\n",
    "    'extracted/valid.pkl',\n",
    "    'extracted/testB.pkl',\n",
    "))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataset Preprocess"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "29 15 15\n",
      "91 30 33\n",
      "2999528\n",
      "2999043\n"
     ]
    }
   ],
   "source": [
    "def print_longest_token():\n",
    "    \"\"\"Print the longest 'query_tokenid' length of the train, valid and test frames.\"\"\"\n",
    "    splits = (tra_df, val_df, test_df)\n",
    "    print(*[max(split['query_tokenid'].str.len()) for split in splits])\n",
    "    \n",
    "def print_longest_roi():\n",
    "    \"\"\"Print the largest first-axis size of 'features' for the train, valid and test frames.\"\"\"\n",
    "    splits = (tra_df, val_df, test_df)\n",
    "    print(*[max(split['features'].apply(lambda x: x.shape[0])) for split in splits])\n",
    "print_longest_token()\n",
    "print_longest_roi()\n",
    "\n",
    "# Keep only training samples with at most 40 ROI rows (the default pad length of pad_rois);\n",
    "# the row count is printed before and after the filter.\n",
    "print(len(tra_df))\n",
    "tra_df_mask = tra_df['features'].map(lambda feats: feats.shape[0]).le(40)\n",
    "tra_df = tra_df.loc[tra_df_mask]\n",
    "print(len(tra_df))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_token(x, l=30):\n",
    "    \"\"\"Right-pad (or truncate) a token-id sequence to a fixed length.\n",
    "\n",
    "    Args:\n",
    "        x: sequence of integer token ids.\n",
    "        l: length of the returned arrays.\n",
    "\n",
    "    Returns:\n",
    "        (tokens, attns): an int64 array of length ``l`` holding the ids followed by\n",
    "        zeros, and a float64 array of length ``l`` that is 1 on real tokens and 0 on padding.\n",
    "    \"\"\"\n",
    "    # np.long / np.float were removed from NumPy; use the explicit fixed-width dtypes.\n",
    "    tokens = np.zeros(l, dtype=np.int64)\n",
    "    attns = np.zeros(l, dtype=np.float64)\n",
    "    # Clip to l so an over-long sequence is truncated instead of raising a broadcast error.\n",
    "    n = min(len(x), l)\n",
    "    tokens[:n] = x[:n]\n",
    "    attns[:n] = 1\n",
    "    return tokens, attns\n",
    "\n",
    "def pad_rois(x, l=40, dim=2048):\n",
    "    \"\"\"Right-pad (or truncate) a stack of ROI feature rows to a fixed number of rows.\n",
    "\n",
    "    Args:\n",
    "        x: array-like of ROI feature rows, each of width ``dim``.\n",
    "        l: number of rows in the returned array.\n",
    "        dim: width of one feature row (2048 by default).\n",
    "\n",
    "    Returns:\n",
    "        (rois, attns): a float64 array of shape ``(l, dim)`` holding the rows followed by\n",
    "        zero rows, and a float64 array of length ``l`` that is 1 on real rows and 0 on padding.\n",
    "    \"\"\"\n",
    "    # np.float was removed from NumPy; use the explicit fixed-width dtype.\n",
    "    rois = np.zeros([l, dim], dtype=np.float64)\n",
    "    attns = np.zeros(l, dtype=np.float64)\n",
    "    # Clip to l so an image with more than l ROIs is truncated instead of raising.\n",
    "    n = min(len(x), l)\n",
    "    rois[:n, :] = x[:n]\n",
    "    attns[:n] = 1\n",
    "    return rois, attns"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataset Definition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TrainSet(Dataset):\n",
    "    \"\"\"Map-style dataset over the module-level training frame ``tra_df``.\n",
    "\n",
    "    Every item is one row of the frame, returned as the padded 4-tuple\n",
    "    ``(tokens, token_mask, rois, roi_mask)``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.df = tra_df\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.df)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        sample = self.df.iloc[idx]\n",
    "        query_ids, query_mask = pad_token(sample['query_tokenid'])\n",
    "        roi_feats, roi_mask = pad_rois(sample['features'])\n",
    "        return query_ids, query_mask, roi_feats, roi_mask\n",
    "\n",
    "\n",
    "tra = TrainSet()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "496\n",
      "(30,) (30,) (30, 40, 2048) (30, 40) (30,)\n"
     ]
    }
   ],
   "source": [
    "class InferSet(Dataset):\n",
    "    \"\"\"Map-style dataset with one item per distinct ``query_id`` of ``df``.\n",
    "\n",
    "    An item bundles the padded query with the padded ROI features of every row\n",
    "    (candidate product) that shares that query id:\n",
    "    ``(tokens, token_mask, rois, roi_masks, query_id, product_ids)``, where ``rois``\n",
    "    and ``roi_masks`` carry a leading axis over the candidate rows.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, df):\n",
    "        super().__init__()\n",
    "        self.df = df\n",
    "        self.qids = df['query_id'].unique().tolist()\n",
    "        # query_id -> positional row indices, grouped once so that __getitem__ does not\n",
    "        # rescan the whole frame with a boolean mask on every call.\n",
    "        self._rows_by_qid = df.groupby('query_id', sort=False).indices\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.qids)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        qid = self.qids[idx]\n",
    "        rows = self.df.iloc[self._rows_by_qid[qid]]\n",
    "\n",
    "        # All rows of a group share the query, so the first row supplies the tokens.\n",
    "        first_row = rows.iloc[0]\n",
    "        tokens, attn1 = pad_token(first_row['query_tokenid'])\n",
    "\n",
    "        padded = [pad_rois(feats) for feats in rows['features']]\n",
    "        rois = np.stack([feats for feats, _ in padded])\n",
    "        attn2 = np.stack([mask for _, mask in padded])\n",
    "        return tokens, attn1, rois, attn2, first_row['query_id'], rows['product_id'].values\n",
    "\n",
    "\n",
    "val = InferSet(val_df)\n",
    "print(len(val))\n",
    "# Smoke test: print the array shapes of the first validation query.\n",
    "tokens, attn1, rois, attn2, qid, pid = val[0]\n",
    "print(tokens.shape, attn1.shape, rois.shape, attn2.shape, pid.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Metric Definition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import json\n",
    "\n",
    "def dcg_at_k(r, k):\n",
    "    \"\"\"Discounted cumulative gain of the first ``k`` relevance scores in ``r``.\n",
    "\n",
    "    The score at rank 1 counts in full; the score at rank ``i >= 2`` is divided by\n",
    "    ``log2(i + 1)``. Returns 0.0 for an empty ranking.\n",
    "    \"\"\"\n",
    "    # np.asfarray was removed in NumPy 2.0; asarray with dtype=float is the equivalent.\n",
    "    r = np.asarray(r, dtype=float)[:k]\n",
    "    if r.size:\n",
    "        return r[0] + np.sum(r[1:] / np.log2(np.arange(3, r.size + 2)))\n",
    "    return 0.\n",
    "\n",
    "def get_ndcg(r, ref, k):\n",
    "    \"\"\"Return DCG@k of ``r`` divided by DCG@k of the reference ranking ``ref``.\n",
    "\n",
    "    Yields 0.0 when the reference DCG is zero, so the division is never undefined.\n",
    "    \"\"\"\n",
    "    ideal = dcg_at_k(ref, k)\n",
    "    return dcg_at_k(r, k) / ideal if ideal else 0.\n",
    "\n",
    "class Valid():\n",
    "    \"\"\"Collects (query, product, score) predictions and scores them with mean NDCG@5.\n",
    "\n",
    "    Ground truth is read from 'dataset/valid_answer.json'; its values are stringified\n",
    "    and stored as lists in ``self.answer``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        self.result = {}\n",
    "        # Context manager so the file handle is closed deterministically.\n",
    "        with open('dataset/valid_answer.json') as answer_file:\n",
    "            parsed_json = json.load(answer_file)\n",
    "        # both k and v are strings\n",
    "        self.answer = {k: [str(x) for x in v] for k, v in parsed_json.items()}\n",
    "\n",
    "    def add_prediction(self, qid, pid, score):\n",
    "        \"\"\"Record ``score`` for product ``pid`` under query ``qid`` (ids are stored as strings).\"\"\"\n",
    "        qid, pid = str(qid), str(pid)\n",
    "        self.result.setdefault(qid, []).append([pid, score])\n",
    "\n",
    "    def cal_ndcg5(self):\n",
    "        \"\"\"Mean NDCG@5 over all queries that received predictions.\n",
    "\n",
    "        The reference ranking is always five relevant items. Returns 0.0 when no\n",
    "        prediction has been added. Raises KeyError for a query id absent from the answers.\n",
    "        \"\"\"\n",
    "        if not self.result:\n",
    "            return 0.0\n",
    "        mNDCG = 0.0\n",
    "        for qid, preds in self.result.items():\n",
    "            gt = set(self.answer[qid])  # set for O(1) membership tests\n",
    "            ranked = sorted(preds, key=lambda x: x[1], reverse=True)\n",
    "            pred_vec = [1.0 if pid in gt else 0.0 for pid, _ in ranked[:5]]\n",
    "            mNDCG += get_ndcg(pred_vec, [1.0] * 5, 5)\n",
    "        return mNDCG / len(self.result)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model Definition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:transformers.file_utils:PyTorch version 1.4.0+cu100 available.\n"
     ]
    }
   ],
   "source": [
    "from transformers import BertModel, BertConfig\n",
    "\n",
    "# Two separate 4-layer BERT stacks built from a config (no pretrained weights are loaded):\n",
    "# one backs the text encoder, the other the image-ROI encoder.\n",
    "bert_text = BertModel(config=BertConfig(num_hidden_layers=4))\n",
    "bert_img = BertModel(config=BertConfig(num_hidden_layers=4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TextEncoder(nn.Module):\n",
    "    \"\"\"Wraps the module-level ``bert_text`` model and returns its pooled output.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super(TextEncoder, self).__init__()\n",
    "        self.bert_text = bert_text\n",
    "\n",
    "    def forward(self, tokens, attns):\n",
    "        \"\"\"Run BERT on ``tokens`` with attention mask ``attns`` and return the pooled output.\"\"\"\n",
    "        outputs = self.bert_text(\n",
    "            input_ids=tokens,\n",
    "            attention_mask=attns)\n",
    "        # Position 1 is the pooled output. Indexing (instead of tuple-unpacking) works both\n",
    "        # for the plain tuple of older transformers releases and for ModelOutput objects.\n",
    "        return outputs[1]\n",
    "\n",
    "class ImageEncoder(nn.Module):\n",
    "    \"\"\"Encodes a padded set of ROI feature vectors into one pooled embedding.\n",
    "\n",
    "    The 2048-wide ROI rows are linearly projected to the BERT hidden size and fed directly\n",
    "    into the encoder stack of the module-level ``bert_img``. Of its embedding module only\n",
    "    the LayerNorm is used: no word, position or token-type embedding and no embedding\n",
    "    dropout is applied here.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super(ImageEncoder, self).__init__()\n",
    "        self.bert_img = bert_img\n",
    "        # Follow the wrapped model's hidden size (768 with the default BertConfig).\n",
    "        self.roi_encoder = nn.Linear(2048, bert_img.config.hidden_size)\n",
    "\n",
    "    def forward(self, rois, attns):\n",
    "        \"\"\"Return the BERT pooler output for ``rois`` under the 0/1 padding mask ``attns``.\n",
    "\n",
    "        ``attns`` is indexed as (batch, roi); ``rois`` ends in the 2048-wide feature axis.\n",
    "        \"\"\"\n",
    "        embedding_output = self.bert_img.embeddings.LayerNorm(self.roi_encoder(rois))\n",
    "        # Additive attention mask: 0 where attns is 1, -10000 where it is 0, with two\n",
    "        # singleton axes inserted so it broadcasts inside the attention layers.\n",
    "        extended_attention_mask = attns[:, None, None, :].to(dtype=embedding_output.dtype)\n",
    "        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0\n",
    "        encoder_outputs = self.bert_img.encoder(\n",
    "            embedding_output,\n",
    "            attention_mask=extended_attention_mask,\n",
    "            # One (empty) head mask per encoder layer.\n",
    "            head_mask=[None] * self.bert_img.config.num_hidden_layers)\n",
    "        sequence_output = encoder_outputs[0]\n",
    "        return self.bert_img.pooler(sequence_output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyModel(pl.LightningModule):\n",
    "    \"\"\"Two-tower text/image matcher trained with an in-batch binary cross-entropy loss.\n",
    "\n",
    "    Training batches hold aligned (query, image) pairs; the text-by-image similarity\n",
    "    matrix is pushed towards the identity, so the other items of the batch act as negatives.\n",
    "    Validation scores every candidate image of one query and reports NDCG@5.\n",
    "\n",
    "    Args:\n",
    "        bs: training batch size used by ``train_dataloader``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, bs=512):\n",
    "        super(MyModel, self).__init__()\n",
    "        self.txt_enc = TextEncoder()\n",
    "        self.img_enc = ImageEncoder()\n",
    "        self.bs = bs\n",
    "\n",
    "    def forward(self, tokens, attn1, rois, attn2):\n",
    "        \"\"\"Return the training loss for a batch of aligned (query, image) pairs.\"\"\"\n",
    "        txt_emb = self.txt_enc(tokens, attn1)\n",
    "        img_emb = self.img_enc(rois, attn2)\n",
    "        sim = txt_emb.mm(img_emb.T)\n",
    "        # Build the identity target on the fly: it follows the actual batch size and the\n",
    "        # device of the logits instead of a float64 matrix pinned to CUDA at construction.\n",
    "        target = torch.eye(sim.size(0), device=sim.device)\n",
    "        return F.binary_cross_entropy_with_logits(sim.float(), target)\n",
    "\n",
    "    def predict(self, tokens, attn1, rois, attn2):\n",
    "        \"\"\"Score every candidate image of a single query; returns a (1, n_candidates) matrix.\"\"\"\n",
    "        # 1 query per prediction\n",
    "        txt_emb = self.txt_enc(tokens, attn1)\n",
    "        # Drop only the DataLoader batch axis. A bare squeeze() would also remove the\n",
    "        # candidate axis when a query has exactly one product.\n",
    "        img_emb = self.img_enc(rois.squeeze(0), attn2.squeeze(0))\n",
    "        return txt_emb.mm(img_emb.T)\n",
    "\n",
    "    # train stuff\n",
    "    def training_step(self, batch, batch_nb):\n",
    "        tokens, attn1, rois, attn2 = batch\n",
    "        loss = self(tokens, attn1, rois, attn2)\n",
    "        tensorboard_logs = {'train_loss': loss}\n",
    "        return {'loss': loss, 'log': tensorboard_logs}\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        return torch.optim.Adam(self.parameters(), lr=2e-05, eps=1e-08)\n",
    "\n",
    "    def train_dataloader(self):\n",
    "        return DataLoader(tra, batch_size=self.bs, shuffle=True,\n",
    "                          num_workers=16, pin_memory=True, drop_last=True)\n",
    "\n",
    "    # valid and test stuff\n",
    "    def validation_step(self, batch, batch_nb):\n",
    "        tokens, attn1, rois, attn2, qid, pid = batch\n",
    "        y_hat = self.predict(tokens, attn1, rois, attn2)\n",
    "        return {'y_hat': y_hat, 'qid': qid, 'pid': pid}\n",
    "\n",
    "    def validation_epoch_end(self, outputs):\n",
    "        \"\"\"Aggregate all validation queries into NDCG@5; ``val_loss`` is ``1 - NDCG@5``.\"\"\"\n",
    "        valid = Valid()\n",
    "        for o in outputs:\n",
    "            # Each output comes from a batch of one query, hence the [0].\n",
    "            qid = o['qid'].cpu().numpy()[0]\n",
    "            pid = o['pid'].cpu().numpy()[0]\n",
    "            scores = o['y_hat'].cpu().numpy()[0]\n",
    "            for idx, p in enumerate(pid):\n",
    "                valid.add_prediction(qid, p, scores[idx])\n",
    "        # Compute the metric once and reuse it for both the log and the monitored loss.\n",
    "        ndcg5 = valid.cal_ndcg5()\n",
    "        tensorboard_logs = {'ndcg5': ndcg5}\n",
    "        return {'val_loss': torch.tensor(1 - ndcg5), 'log': tensorboard_logs}\n",
    "\n",
    "    def val_dataloader(self):\n",
    "        return DataLoader(val, batch_size=1)\n",
    "\n",
    "\n",
    "model = MyModel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:lightning:GPU available: True, used: True\n",
      "INFO:lightning:VISIBLE GPUS: 0\n",
      "INFO:lightning:Using 16bit precision.\n",
      "INFO:lightning:\n",
      "    | Name                                                         | Type              | Params\n",
      "-----------------------------------------------------------------------------------------------\n",
      "0   | txt_enc                                                      | TextEncoder       | 52 M  \n",
      "1   | txt_enc.bert_text                                            | BertModel         | 52 M  \n",
      "2   | txt_enc.bert_text.embeddings                                 | BertEmbeddings    | 23 M  \n",
      "3   | txt_enc.bert_text.embeddings.word_embeddings                 | Embedding         | 23 M  \n",
      "4   | txt_enc.bert_text.embeddings.position_embeddings             | Embedding         | 393 K \n",
      "5   | txt_enc.bert_text.embeddings.token_type_embeddings           | Embedding         | 1 K   \n",
      "6   | txt_enc.bert_text.embeddings.LayerNorm                       | LayerNorm         | 1 K   \n",
      "7   | txt_enc.bert_text.embeddings.dropout                         | Dropout           | 0     \n",
      "8   | txt_enc.bert_text.encoder                                    | BertEncoder       | 28 M  \n",
      "9   | txt_enc.bert_text.encoder.layer                              | ModuleList        | 28 M  \n",
      "10  | txt_enc.bert_text.encoder.layer.0                            | BertLayer         | 7 M   \n",
      "11  | txt_enc.bert_text.encoder.layer.0.attention                  | BertAttention     | 2 M   \n",
      "12  | txt_enc.bert_text.encoder.layer.0.attention.self             | BertSelfAttention | 1 M   \n",
      "13  | txt_enc.bert_text.encoder.layer.0.attention.self.query       | Linear            | 590 K \n",
      "14  | txt_enc.bert_text.encoder.layer.0.attention.self.key         | Linear            | 590 K \n",
      "15  | txt_enc.bert_text.encoder.layer.0.attention.self.value       | Linear            | 590 K \n",
      "16  | txt_enc.bert_text.encoder.layer.0.attention.self.dropout     | Dropout           | 0     \n",
      "17  | txt_enc.bert_text.encoder.layer.0.attention.output           | BertSelfOutput    | 592 K \n",
      "18  | txt_enc.bert_text.encoder.layer.0.attention.output.dense     | Linear            | 590 K \n",
      "19  | txt_enc.bert_text.encoder.layer.0.attention.output.LayerNorm | LayerNorm         | 1 K   \n",
      "20  | txt_enc.bert_text.encoder.layer.0.attention.output.dropout   | Dropout           | 0     \n",
      "21  | txt_enc.bert_text.encoder.layer.0.intermediate               | BertIntermediate  | 2 M   \n",
      "22  | txt_enc.bert_text.encoder.layer.0.intermediate.dense         | Linear            | 2 M   \n",
      "23  | txt_enc.bert_text.encoder.layer.0.output                     | BertOutput        | 2 M   \n",
      "24  | txt_enc.bert_text.encoder.layer.0.output.dense               | Linear            | 2 M   \n",
      "25  | txt_enc.bert_text.encoder.layer.0.output.LayerNorm           | LayerNorm         | 1 K   \n",
      "26  | txt_enc.bert_text.encoder.layer.0.output.dropout             | Dropout           | 0     \n",
      "27  | txt_enc.bert_text.encoder.layer.1                            | BertLayer         | 7 M   \n",
      "28  | txt_enc.bert_text.encoder.layer.1.attention                  | BertAttention     | 2 M   \n",
      "29  | txt_enc.bert_text.encoder.layer.1.attention.self             | BertSelfAttention | 1 M   \n",
      "30  | txt_enc.bert_text.encoder.layer.1.attention.self.query       | Linear            | 590 K \n",
      "31  | txt_enc.bert_text.encoder.layer.1.attention.self.key         | Linear            | 590 K \n",
      "32  | txt_enc.bert_text.encoder.layer.1.attention.self.value       | Linear            | 590 K \n",
      "33  | txt_enc.bert_text.encoder.layer.1.attention.self.dropout     | Dropout           | 0     \n",
      "34  | txt_enc.bert_text.encoder.layer.1.attention.output           | BertSelfOutput    | 592 K \n",
      "35  | txt_enc.bert_text.encoder.layer.1.attention.output.dense     | Linear            | 590 K \n",
      "36  | txt_enc.bert_text.encoder.layer.1.attention.output.LayerNorm | LayerNorm         | 1 K   \n",
      "37  | txt_enc.bert_text.encoder.layer.1.attention.output.dropout   | Dropout           | 0     \n",
      "38  | txt_enc.bert_text.encoder.layer.1.intermediate               | BertIntermediate  | 2 M   \n",
      "39  | txt_enc.bert_text.encoder.layer.1.intermediate.dense         | Linear            | 2 M   \n",
      "40  | txt_enc.bert_text.encoder.layer.1.output                     | BertOutput        | 2 M   \n",
      "41  | txt_enc.bert_text.encoder.layer.1.output.dense               | Linear            | 2 M   \n",
      "42  | txt_enc.bert_text.encoder.layer.1.output.LayerNorm           | LayerNorm         | 1 K   \n",
      "43  | txt_enc.bert_text.encoder.layer.1.output.dropout             | Dropout           | 0     \n",
      "44  | txt_enc.bert_text.encoder.layer.2                            | BertLayer         | 7 M   \n",
      "45  | txt_enc.bert_text.encoder.layer.2.attention                  | BertAttention     | 2 M   \n",
      "46  | txt_enc.bert_text.encoder.layer.2.attention.self             | BertSelfAttention | 1 M   \n",
      "47  | txt_enc.bert_text.encoder.layer.2.attention.self.query       | Linear            | 590 K \n",
      "48  | txt_enc.bert_text.encoder.layer.2.attention.self.key         | Linear            | 590 K \n",
      "49  | txt_enc.bert_text.encoder.layer.2.attention.self.value       | Linear            | 590 K \n",
      "50  | txt_enc.bert_text.encoder.layer.2.attention.self.dropout     | Dropout           | 0     \n",
      "51  | txt_enc.bert_text.encoder.layer.2.attention.output           | BertSelfOutput    | 592 K \n",
      "52  | txt_enc.bert_text.encoder.layer.2.attention.output.dense     | Linear            | 590 K \n",
      "53  | txt_enc.bert_text.encoder.layer.2.attention.output.LayerNorm | LayerNorm         | 1 K   \n",
      "54  | txt_enc.bert_text.encoder.layer.2.attention.output.dropout   | Dropout           | 0     \n",
      "55  | txt_enc.bert_text.encoder.layer.2.intermediate               | BertIntermediate  | 2 M   \n",
      "56  | txt_enc.bert_text.encoder.layer.2.intermediate.dense         | Linear            | 2 M   \n",
      "57  | txt_enc.bert_text.encoder.layer.2.output                     | BertOutput        | 2 M   \n",
      "58  | txt_enc.bert_text.encoder.layer.2.output.dense               | Linear            | 2 M   \n",
      "59  | txt_enc.bert_text.encoder.layer.2.output.LayerNorm           | LayerNorm         | 1 K   \n",
      "60  | txt_enc.bert_text.encoder.layer.2.output.dropout             | Dropout           | 0     \n",
      "61  | txt_enc.bert_text.encoder.layer.3                            | BertLayer         | 7 M   \n",
      "62  | txt_enc.bert_text.encoder.layer.3.attention                  | BertAttention     | 2 M   \n",
      "63  | txt_enc.bert_text.encoder.layer.3.attention.self             | BertSelfAttention | 1 M   \n",
      "64  | txt_enc.bert_text.encoder.layer.3.attention.self.query       | Linear            | 590 K \n",
      "65  | txt_enc.bert_text.encoder.layer.3.attention.self.key         | Linear            | 590 K \n",
      "66  | txt_enc.bert_text.encoder.layer.3.attention.self.value       | Linear            | 590 K \n",
      "67  | txt_enc.bert_text.encoder.layer.3.attention.self.dropout     | Dropout           | 0     \n",
      "68  | txt_enc.bert_text.encoder.layer.3.attention.output           | BertSelfOutput    | 592 K \n",
      "69  | txt_enc.bert_text.encoder.layer.3.attention.output.dense     | Linear            | 590 K \n",
      "70  | txt_enc.bert_text.encoder.layer.3.attention.output.LayerNorm | LayerNorm         | 1 K   \n",
      "71  | txt_enc.bert_text.encoder.layer.3.attention.output.dropout   | Dropout           | 0     \n",
      "72  | txt_enc.bert_text.encoder.layer.3.intermediate               | BertIntermediate  | 2 M   \n",
      "73  | txt_enc.bert_text.encoder.layer.3.intermediate.dense         | Linear            | 2 M   \n",
      "74  | txt_enc.bert_text.encoder.layer.3.output                     | BertOutput        | 2 M   \n",
      "75  | txt_enc.bert_text.encoder.layer.3.output.dense               | Linear            | 2 M   \n",
      "76  | txt_enc.bert_text.encoder.layer.3.output.LayerNorm           | LayerNorm         | 1 K   \n",
      "77  | txt_enc.bert_text.encoder.layer.3.output.dropout             | Dropout           | 0     \n",
      "78  | txt_enc.bert_text.pooler                                     | BertPooler        | 590 K \n",
      "79  | txt_enc.bert_text.pooler.dense                               | Linear            | 590 K \n",
      "80  | txt_enc.bert_text.pooler.activation                          | Tanh              | 0     \n",
      "81  | img_enc                                                      | ImageEncoder      | 54 M  \n",
      "82  | img_enc.bert_img                                             | BertModel         | 52 M  \n",
      "83  | img_enc.bert_img.embeddings                                  | BertEmbeddings    | 23 M  \n",
      "84  | img_enc.bert_img.embeddings.word_embeddings                  | Embedding         | 23 M  \n",
      "85  | img_enc.bert_img.embeddings.position_embeddings              | Embedding         | 393 K \n",
      "86  | img_enc.bert_img.embeddings.token_type_embeddings            | Embedding         | 1 K   \n",
      "87  | img_enc.bert_img.embeddings.LayerNorm                        | LayerNorm         | 1 K   \n",
      "88  | img_enc.bert_img.embeddings.dropout                          | Dropout           | 0     \n",
      "89  | img_enc.bert_img.encoder                                     | BertEncoder       | 28 M  \n",
      "90  | img_enc.bert_img.encoder.layer                               | ModuleList        | 28 M  \n",
      "91  | img_enc.bert_img.encoder.layer.0                             | BertLayer         | 7 M   \n",
      "92  | img_enc.bert_img.encoder.layer.0.attention                   | BertAttention     | 2 M   \n",
      "93  | img_enc.bert_img.encoder.layer.0.attention.self              | BertSelfAttention | 1 M   \n",
      "94  | img_enc.bert_img.encoder.layer.0.attention.self.query        | Linear            | 590 K \n",
      "95  | img_enc.bert_img.encoder.layer.0.attention.self.key          | Linear            | 590 K \n",
      "96  | img_enc.bert_img.encoder.layer.0.attention.self.value        | Linear            | 590 K \n",
      "97  | img_enc.bert_img.encoder.layer.0.attention.self.dropout      | Dropout           | 0     \n",
      "98  | img_enc.bert_img.encoder.layer.0.attention.output            | BertSelfOutput    | 592 K \n",
      "99  | img_enc.bert_img.encoder.layer.0.attention.output.dense      | Linear            | 590 K \n",
      "100 | img_enc.bert_img.encoder.layer.0.attention.output.LayerNorm  | LayerNorm         | 1 K   \n",
      "101 | img_enc.bert_img.encoder.layer.0.attention.output.dropout    | Dropout           | 0     \n",
      "102 | img_enc.bert_img.encoder.layer.0.intermediate                | BertIntermediate  | 2 M   \n",
      "103 | img_enc.bert_img.encoder.layer.0.intermediate.dense          | Linear            | 2 M   \n",
      "104 | img_enc.bert_img.encoder.layer.0.output                      | BertOutput        | 2 M   \n",
      "105 | img_enc.bert_img.encoder.layer.0.output.dense                | Linear            | 2 M   \n",
      "106 | img_enc.bert_img.encoder.layer.0.output.LayerNorm            | LayerNorm         | 1 K   \n",
      "107 | img_enc.bert_img.encoder.layer.0.output.dropout              | Dropout           | 0     \n",
      "108 | img_enc.bert_img.encoder.layer.1                             | BertLayer         | 7 M   \n",
      "109 | img_enc.bert_img.encoder.layer.1.attention                   | BertAttention     | 2 M   \n",
      "110 | img_enc.bert_img.encoder.layer.1.attention.self              | BertSelfAttention | 1 M   \n",
      "111 | img_enc.bert_img.encoder.layer.1.attention.self.query        | Linear            | 590 K \n",
      "112 | img_enc.bert_img.encoder.layer.1.attention.self.key          | Linear            | 590 K \n",
      "113 | img_enc.bert_img.encoder.layer.1.attention.self.value        | Linear            | 590 K \n",
      "114 | img_enc.bert_img.encoder.layer.1.attention.self.dropout      | Dropout           | 0     \n",
      "115 | img_enc.bert_img.encoder.layer.1.attention.output            | BertSelfOutput    | 592 K \n",
      "116 | img_enc.bert_img.encoder.layer.1.attention.output.dense      | Linear            | 590 K \n",
      "117 | img_enc.bert_img.encoder.layer.1.attention.output.LayerNorm  | LayerNorm         | 1 K   \n",
      "118 | img_enc.bert_img.encoder.layer.1.attention.output.dropout    | Dropout           | 0     \n",
      "119 | img_enc.bert_img.encoder.layer.1.intermediate                | BertIntermediate  | 2 M   \n",
      "120 | img_enc.bert_img.encoder.layer.1.intermediate.dense          | Linear            | 2 M   \n",
      "121 | img_enc.bert_img.encoder.layer.1.output                      | BertOutput        | 2 M   \n",
      "122 | img_enc.bert_img.encoder.layer.1.output.dense                | Linear            | 2 M   \n",
      "123 | img_enc.bert_img.encoder.layer.1.output.LayerNorm            | LayerNorm         | 1 K   \n",
      "124 | img_enc.bert_img.encoder.layer.1.output.dropout              | Dropout           | 0     \n",
      "125 | img_enc.bert_img.encoder.layer.2                             | BertLayer         | 7 M   \n",
      "126 | img_enc.bert_img.encoder.layer.2.attention                   | BertAttention     | 2 M   \n",
      "127 | img_enc.bert_img.encoder.layer.2.attention.self              | BertSelfAttention | 1 M   \n",
      "128 | img_enc.bert_img.encoder.layer.2.attention.self.query        | Linear            | 590 K \n",
      "129 | img_enc.bert_img.encoder.layer.2.attention.self.key          | Linear            | 590 K \n",
      "130 | img_enc.bert_img.encoder.layer.2.attention.self.value        | Linear            | 590 K \n",
      "131 | img_enc.bert_img.encoder.layer.2.attention.self.dropout      | Dropout           | 0     \n",
      "132 | img_enc.bert_img.encoder.layer.2.attention.output            | BertSelfOutput    | 592 K \n",
      "133 | img_enc.bert_img.encoder.layer.2.attention.output.dense      | Linear            | 590 K \n",
      "134 | img_enc.bert_img.encoder.layer.2.attention.output.LayerNorm  | LayerNorm         | 1 K   \n",
      "135 | img_enc.bert_img.encoder.layer.2.attention.output.dropout    | Dropout           | 0     \n",
      "136 | img_enc.bert_img.encoder.layer.2.intermediate                | BertIntermediate  | 2 M   \n",
      "137 | img_enc.bert_img.encoder.layer.2.intermediate.dense          | Linear            | 2 M   \n",
      "138 | img_enc.bert_img.encoder.layer.2.output                      | BertOutput        | 2 M   \n",
      "139 | img_enc.bert_img.encoder.layer.2.output.dense                | Linear            | 2 M   \n",
      "140 | img_enc.bert_img.encoder.layer.2.output.LayerNorm            | LayerNorm         | 1 K   \n",
      "141 | img_enc.bert_img.encoder.layer.2.output.dropout              | Dropout           | 0     \n",
      "142 | img_enc.bert_img.encoder.layer.3                             | BertLayer         | 7 M   \n",
      "143 | img_enc.bert_img.encoder.layer.3.attention                   | BertAttention     | 2 M   \n",
      "144 | img_enc.bert_img.encoder.layer.3.attention.self              | BertSelfAttention | 1 M   \n",
      "145 | img_enc.bert_img.encoder.layer.3.attention.self.query        | Linear            | 590 K \n",
      "146 | img_enc.bert_img.encoder.layer.3.attention.self.key          | Linear            | 590 K \n",
      "147 | img_enc.bert_img.encoder.layer.3.attention.self.value        | Linear            | 590 K \n",
      "148 | img_enc.bert_img.encoder.layer.3.attention.self.dropout      | Dropout           | 0     \n",
      "149 | img_enc.bert_img.encoder.layer.3.attention.output            | BertSelfOutput    | 592 K \n",
      "150 | img_enc.bert_img.encoder.layer.3.attention.output.dense      | Linear            | 590 K \n",
      "151 | img_enc.bert_img.encoder.layer.3.attention.output.LayerNorm  | LayerNorm         | 1 K   \n",
      "152 | img_enc.bert_img.encoder.layer.3.attention.output.dropout    | Dropout           | 0     \n",
      "153 | img_enc.bert_img.encoder.layer.3.intermediate                | BertIntermediate  | 2 M   \n",
      "154 | img_enc.bert_img.encoder.layer.3.intermediate.dense          | Linear            | 2 M   \n",
      "155 | img_enc.bert_img.encoder.layer.3.output                      | BertOutput        | 2 M   \n",
      "156 | img_enc.bert_img.encoder.layer.3.output.dense                | Linear            | 2 M   \n",
      "157 | img_enc.bert_img.encoder.layer.3.output.LayerNorm            | LayerNorm         | 1 K   \n",
      "158 | img_enc.bert_img.encoder.layer.3.output.dropout              | Dropout           | 0     \n",
      "159 | img_enc.bert_img.pooler                                      | BertPooler        | 590 K \n",
      "160 | img_enc.bert_img.pooler.dense                                | Linear            | 590 K \n",
      "161 | img_enc.bert_img.pooler.activation                           | Tanh              | 0     \n",
      "162 | img_enc.roi_encoder                                          | Linear            | 1 M   \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.\n",
      "\n",
      "Defaults for this optimization level are:\n",
      "enabled                : True\n",
      "opt_level              : O1\n",
      "cast_model_type        : None\n",
      "patch_torch_functions  : True\n",
      "keep_batchnorm_fp32    : None\n",
      "master_weights         : None\n",
      "loss_scale             : dynamic\n",
      "Processing user overrides (additional kwargs that are not None)...\n",
      "After processing overrides, optimization options are:\n",
      "enabled                : True\n",
      "opt_level              : O1\n",
      "cast_model_type        : None\n",
      "patch_torch_functions  : True\n",
      "keep_batchnorm_fp32    : None\n",
      "master_weights         : None\n",
      "loss_scale             : dynamic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/amax/miniconda3/envs/ycx_light/lib/python3.8/site-packages/pytorch_lightning/utilities/warnings.py:18: UserWarning: The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` in the `DataLoader` init to improve performance.\n",
      "  warnings.warn(*args, **kwargs)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Validation sanity check', layout=Layout(flex='2'), max=5.…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "727cbb391be24b0fa111392d764f5b5f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/amax/miniconda3/envs/ycx_light/lib/python3.8/site-packages/pytorch_lightning/utilities/warnings.py:18: RuntimeWarning: Displayed epoch numbers in the progress bar start from \"1\" until v0.6.x, but will start from \"0\" in v0.8.0.\n",
      "  warnings.warn(*args, **kwargs)\n"
     ]
    }
   ],
   "source": [
    "# To resume from a saved checkpoint instead of the freshly built model, uncomment:\n",
    "# model = MyModel.load_from_checkpoint('lightning_logs/version_4/checkpoints/epoch=13.ckpt')\n",
    "from pytorch_lightning.callbacks import ModelCheckpoint\n",
    "\n",
    "# Checkpoints are selected on the lowest monitored 'val_loss'.\n",
    "checkpoint_callback = ModelCheckpoint(\n",
    "    filepath='./checkpoint1/' + '{epoch:02d}-{val_loss:.2f}',\n",
    "    verbose=True,\n",
    "    monitor='val_loss',\n",
    "    mode='min',\n",
    ")\n",
    "# Single GPU, 16-bit precision, validation check interval of 0.5.\n",
    "trainer = pl.Trainer(\n",
    "    gpus=1,\n",
    "    precision=16,\n",
    "    val_check_interval=0.5,\n",
    "    checkpoint_callback=checkpoint_callback,\n",
    ")\n",
    "trainer.fit(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
